import numpy as np
import time
import matplotlib.pyplot as plt

from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    classification_report,
    confusion_matrix,
    roc_curve,
    precision_recall_curve
)
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

# Training split: features + pseudo-labels produced by an earlier
# pseudo-labelling stage and saved as a compressed .npz archive.
# NOTE(review): the path below is a placeholder — fill in the real file path.
pseudo_data = np.load(
"Path to file pseudolabel_results.npz"
)

X_train = pseudo_data["combined_features"]
y_train = pseudo_data["combined_labels"]

# Validation split: DINOv2 features extracted from the held-out test set.
# NOTE(review): placeholder path — fill in the real file path.
val_data = np.load(
"Path to extracted test features using dinov2 test.npz"
)

X_val = val_data["features"]
y_val = val_data["labels"]

# Sanity-print the array shapes before training.
print("Train:", X_train.shape)
print("Validation:", X_val.shape)


# Module-level accumulators filled by evaluate_model() and consumed by the
# comparison plots at the bottom of the script.
model_names = []   # display name per evaluated model
f1_scores = []     # F1 score per model (bar chart)
roc_curves = []    # (name, fpr, tpr) per model with probability scores
pr_curves = []     # (name, recall, precision) per model with probability scores
sens_list = []     # sensitivity (recall of positive class) per model
spec_list = []     # specificity (recall of negative class) per model

def evaluate_model(model, name):
    """Fit *model* on the global train split, score it on the validation
    split, print a metric report, and record results in the module-level
    accumulators (model_names, f1_scores, roc_curves, pr_curves,
    sens_list, spec_list).

    Parameters
    ----------
    model : scikit-learn style estimator
        Must provide fit/predict; predict_proba or decision_function is
        used, when available, for AUC and ROC/PR curves.
    name : str
        Display name used in printouts and plot legends.

    NOTE(review): assumes a binary task with {0, 1} labels — the metric
    calls below use the default positive label.
    """
    print("Training", name)

    start = time.time()
    model.fit(X_train, y_train)
    train_time = time.time() - start

    start = time.time()
    y_pred = model.predict(X_val)
    test_time = time.time() - start

    # Continuous scores for AUC and the curves: prefer calibrated
    # probabilities, fall back to the decision function (e.g. an SVC
    # trained without probability=True would otherwise get no AUC).
    if hasattr(model, "predict_proba"):
        y_prob = model.predict_proba(X_val)[:, 1]
    elif hasattr(model, "decision_function"):
        y_prob = model.decision_function(X_val)
    else:
        y_prob = None

    acc = accuracy_score(y_val, y_pred)
    prec = precision_score(y_val, y_pred)
    rec = recall_score(y_val, y_pred)
    f1 = f1_score(y_val, y_pred)

    tn, fp, fn, tp = confusion_matrix(y_val, y_pred).ravel()

    # Guard the ratios: if one class is absent from y_val the original
    # expressions raise ZeroDivisionError.
    sensitivity = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0

    print("Accuracy      :", format(acc, '.4f'))
    print("Precision     :", format(prec, '.4f'))
    print("Recall        :", format(rec, '.4f'))
    print("F1 Score      :", format(f1, '.4f'))
    print("Sensitivity   :", format(sensitivity, '.4f'))
    print("Specificity   :", format(specificity, '.4f'))

    if y_prob is not None:
        auc = roc_auc_score(y_val, y_prob)
        print("AUC           :", format(auc, '.4f'))

        fpr, tpr, _ = roc_curve(y_val, y_prob)
        precision_curve, recall_curve, _ = precision_recall_curve(y_val, y_prob)

        roc_curves.append((name, fpr, tpr))
        pr_curves.append((name, recall_curve, precision_curve))

    print("Train Time    :", round(train_time, 2), "s")
    print("Test Time     :", round(test_time, 2), "s")

    print("\nClassification Report")
    print(classification_report(y_val, y_pred, digits=4))

    print("Confusion Matrix")
    print(confusion_matrix(y_val, y_pred))

    model_names.append(name)
    f1_scores.append(f1)
    sens_list.append(sensitivity)
    spec_list.append(specificity)

# Linear SVM behind a standardisation step; probability=True so
# evaluate_model can draw ROC/PR curves for it.
linear_svc = SVC(kernel="linear", probability=True, random_state=42)
svm_model = Pipeline(steps=[
    ("scaler", StandardScaler()),
    ("svm", linear_svc),
])

evaluate_model(svm_model, "Linear SVM")


# Random forest baseline; n_jobs=-1 uses every available core.
rf_params = {"n_estimators": 100, "random_state": 42, "n_jobs": -1}
rf_model = RandomForestClassifier(**rf_params)

evaluate_model(rf_model, "Random Forest")


# XGBoost is optional: the original code referenced XGB_AVAILABLE and xgb
# without ever defining them (NameError at runtime). Import lazily here so
# the script still runs — simply skipping this model — when the package is
# not installed.
try:
    import xgboost as xgb
    XGB_AVAILABLE = True
except ImportError:
    XGB_AVAILABLE = False

if XGB_AVAILABLE:
    xgb_model = xgb.XGBClassifier(
        n_estimators=100,
        max_depth=6,
        learning_rate=0.1,
        eval_metric="logloss",
        random_state=42,
    )

    evaluate_model(xgb_model, "XGBoost")

# Bar chart comparing models by F1 score, with each score printed just
# above its bar.
plt.figure()
plt.bar(model_names, f1_scores)
plt.title("Model Comparison (F1 Score)")
plt.xlabel("Model")
plt.ylabel("F1 Score")
for bar_idx, score in enumerate(f1_scores):
    plt.text(bar_idx, score + 0.01, f"{score:.4f}", ha="center")
plt.show()

# Overlay the ROC curve of every model that produced probability scores,
# plus the chance diagonal.
plt.figure()
for model_label, false_pos_rate, true_pos_rate in roc_curves:
    plt.plot(false_pos_rate, true_pos_rate, label=model_label)
plt.plot([0, 1], [0, 1], '--')
plt.title("ROC Curve")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.show()

# Precision-recall curve per model (x = recall, y = precision).
plt.figure()
for model_label, recall_vals, precision_vals in pr_curves:
    plt.plot(recall_vals, precision_vals, label=model_label)
plt.title("Precision Recall Curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
plt.show()

# Sensitivity-vs-specificity scatter. The original figure had no title,
# axis labels, point labels, or plt.show() — unlike every other figure —
# so under a non-interactive backend it never rendered at all.
plt.figure()
plt.scatter(spec_list, sens_list)
for i, name in enumerate(model_names):
    plt.annotate(name, (spec_list[i], sens_list[i]))
plt.title("Sensitivity vs Specificity")
plt.xlabel("Specificity")
plt.ylabel("Sensitivity")
plt.show()




